#include "MotionBlurShader.h"

//Constructs the motion blur shader: forwards the device and window handle to
//BaseShader, then hands the compiled vertex/pixel shader files to initShader.
MotionBlurShader::MotionBlurShader(ID3D11Device* device, HWND hwnd)
	: BaseShader(device, hwnd)
{
	const wchar_t* const vertexShaderFile = L"postProcessing_vs.cso";
	const wchar_t* const pixelShaderFile = L"motionBlur_ps.cso";

	initShader(vertexShaderFile, pixelShaderFile);
}


//Releases the COM objects this destructor is responsible for (the two constant
//buffers, the sampler state and the input layout). Each pointer is nulled after
//release so that later code checking the pointer does not release it again.
MotionBlurShader::~MotionBlurShader()
{
	//Release the buffers
	if (matrixBuffer)
	{
		matrixBuffer->Release();
		matrixBuffer = nullptr;
	}

	if (blurBuffer)
	{
		blurBuffer->Release();
		blurBuffer = nullptr;
	}

	//Release the samplers
	if (sampleState)
	{
		sampleState->Release();
		sampleState = nullptr;
	}

	//Release the input layout
	if (layout)
	{
		layout->Release();
		layout = nullptr;
	}

	//NOTE: BaseShader::~BaseShader() must NOT be called explicitly here. The
	//compiler invokes the base class destructor automatically once this body
	//finishes; calling it by hand would destroy the base subobject twice,
	//which is undefined behaviour.
}

//Loads the shader stages and creates the GPU objects used by this shader:
//the matrix constant buffer, the blur constant buffer and the texture sampler.
//  vsFilename - compiled vertex shader file, passed to loadTextureVertexShader
//  psFilename - compiled pixel shader file, passed to loadPixelShader
//NOTE: the HRESULTs returned by CreateBuffer/CreateSamplerState are not checked.
void MotionBlurShader::initShader(const wchar_t* vsFilename, const wchar_t* psFilename)
{
	loadTextureVertexShader(vsFilename);
	loadPixelShader(psFilename);

	//Setup the matrix buffer
	//Descriptors are value-initialised so no field is ever passed to D3D with an
	//indeterminate value; only the non-zero fields are then set explicitly.
	D3D11_BUFFER_DESC matrixBufferDesc = {};
	matrixBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	matrixBufferDesc.ByteWidth = sizeof(MatrixBufferType);
	matrixBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	matrixBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	matrixBufferDesc.MiscFlags = 0;
	matrixBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&matrixBufferDesc, nullptr, &matrixBuffer);

	//Setup the blur buffer
	//Same dynamic, CPU-writable constant buffer settings; only the size differs.
	D3D11_BUFFER_DESC blurBufferDesc = {};
	blurBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
	blurBufferDesc.ByteWidth = sizeof(BlurBufferType);
	blurBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
	blurBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
	blurBufferDesc.MiscFlags = 0;
	blurBufferDesc.StructureByteStride = 0;
	renderer->CreateBuffer(&blurBufferDesc, nullptr, &blurBuffer);

	//Setup the sampler
	//Value-initialisation also zeroes BorderColor, which was previously left
	//uninitialised (it is not assigned below because all address modes clamp).
	D3D11_SAMPLER_DESC samplerDesc = {};
	samplerDesc.Filter = D3D11_FILTER_ANISOTROPIC;
	samplerDesc.AddressU = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.AddressV = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.AddressW = D3D11_TEXTURE_ADDRESS_CLAMP;
	samplerDesc.MipLODBias = 0.0f;
	samplerDesc.MaxAnisotropy = 1;
	samplerDesc.ComparisonFunc = D3D11_COMPARISON_ALWAYS;
	samplerDesc.MinLOD = 0;
	samplerDesc.MaxLOD = D3D11_FLOAT32_MAX;
	renderer->CreateSamplerState(&samplerDesc, &sampleState);
}

//Uploads the per-draw constants and binds the resources used by the shader.
//  deviceContext    - context used to map the buffers and bind resources
//  worldMatrix,
//  viewMatrix,
//  projectionMatrix - written (transposed) into the vertex shader constant buffer, slot 0
//  renderTexture    - bound as pixel shader resource, slot 0
//  direction        - blur direction written into the pixel shader constant buffer, slot 0
//If mapping a buffer fails, that buffer's contents are left untouched rather than
//writing through an invalid pointer; the buffer is still bound.
void MotionBlurShader::setShaderParameters(ID3D11DeviceContext* deviceContext, const XMMATRIX &worldMatrix, const XMMATRIX &viewMatrix, const XMMATRIX &projectionMatrix, ID3D11ShaderResourceView* renderTexture, XMFLOAT2 direction)
{
	D3D11_MAPPED_SUBRESOURCE mappedResource = {};

	//Map the matrix buffer
	//Only dereference pData when Map succeeded; on failure it is not a valid pointer.
	if (SUCCEEDED(deviceContext->Map(matrixBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		MatrixBufferType* matrixPtr = static_cast<MatrixBufferType*>(mappedResource.pData);
		//Matrices are transposed before upload (presumably to match the shader's matrix layout).
		matrixPtr->world = XMMatrixTranspose(worldMatrix);
		matrixPtr->view = XMMatrixTranspose(viewMatrix);
		matrixPtr->projection = XMMatrixTranspose(projectionMatrix);
		deviceContext->Unmap(matrixBuffer, 0);
	}
	deviceContext->VSSetConstantBuffers(0, 1, &matrixBuffer);

	//Map the blur buffer
	if (SUCCEEDED(deviceContext->Map(blurBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mappedResource)))
	{
		BlurBufferType* blurPtr = static_cast<BlurBufferType*>(mappedResource.pData);
		blurPtr->direction = direction;
		//NOTE(review): screen dimensions are hard-coded to 1200x675 here; they do not
		//follow the actual window/render target size.
		blurPtr->screenDimensions = XMFLOAT2(1200.0f, 675.0f);
		deviceContext->Unmap(blurBuffer, 0);
	}
	deviceContext->PSSetConstantBuffers(0, 1, &blurBuffer);

	//Set the pixel shader textures and samplers
	deviceContext->PSSetShaderResources(0, 1, &renderTexture);
	deviceContext->PSSetSamplers(0, 1, &sampleState);
}